import collections
import copy
import math
import os
import re

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D
from paddle.nn import AvgPool2D
from paddle.nn import BatchNorm
from paddle.nn import Conv2D
from paddle.nn import Dropout
from paddle.nn import Linear
from paddle.nn import MaxPool2D

# Download URLs for the official pretrained weights of each supported variant.
MODEL_URLS = {
    "EfficientNetB0_small":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB0_small_pretrained.pdparams",
    "EfficientNetB0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB0_pretrained.pdparams",
    "EfficientNetB1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB1_pretrained.pdparams",
    "EfficientNetB2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB2_pretrained.pdparams",
    "EfficientNetB3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB3_pretrained.pdparams",
    "EfficientNetB4": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB4_pretrained.pdparams",
    "EfficientNetB5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB5_pretrained.pdparams",
    "EfficientNetB6": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB6_pretrained.pdparams",
    "EfficientNetB7": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB7_pretrained.pdparams",
}

# Public API: one builder function per pretrained variant.
__all__ = list(MODEL_URLS.keys())

# Network-wide hyper-parameters shared by every stage of the model.
GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum',
    'batch_norm_epsilon',
    'dropout_rate',
    'num_classes',
    'width_coefficient',
    'depth_coefficient',
    'depth_divisor',
    'min_depth',
    'drop_connect_rate',
])

# Per-stage MBConv configuration, decoded from strings like 'r1_k3_s11_e1_i32_o16_se0.25'.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    ['kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])

# Default every field to None so partially-specified tuples can be constructed.
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)


def load_dygraph_pretrain(model, path=None):
    """Load pretrained parameters from ``path + '.pdparams'`` into ``model``.

    Args:
        model: paddle.nn.Layer whose state dict will be replaced.
        path: checkpoint path prefix (without the ``.pdparams`` suffix).

    Raises:
        ValueError: if ``path`` is None or no checkpoint exists there.
          (Previously a None path crashed with TypeError on string concat.)
    """
    if path is None or not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    model.set_dict(param_state_dict)
    return


def efficientnet_params(model_name):
    """ Map EfficientNet model name to parameter coefficients.

    Args:
        model_name: e.g. 'efficientnet-b0' .. 'efficientnet-b7'.

    Returns:
        tuple: (width_coefficient, depth_coefficient, resolution, dropout_rate).

    Raises:
        KeyError: if ``model_name`` is not a known variant.
    """
    params_dict = {
        # Coefficients:   width,depth,resolution,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        # b1 was missing, which made EfficientNetB1() raise KeyError.
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    }
    return params_dict[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, drop_connect_rate=0.2):
    """Build the baseline (B0) block args plus scaled global params.

    Args:
        width_coefficient: channel multiplier (None keeps base widths).
        depth_coefficient: repeat multiplier (None keeps base depths).
        dropout_rate: dropout before the classifier head.
        drop_connect_rate: maximum stochastic-depth drop rate.

    Returns:
        (list of BlockArgs, GlobalParams)
    """
    # One string per stage: r=repeats, k=kernel, s=stride, e=expand ratio,
    # i=input channels, o=output channels, se=squeeze-excite ratio.
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    global_params = GlobalParams(batch_norm_momentum=0.99,
                                 batch_norm_epsilon=1e-3,
                                 dropout_rate=dropout_rate,
                                 drop_connect_rate=drop_connect_rate,
                                 num_classes=1000,
                                 width_coefficient=width_coefficient,
                                 depth_coefficient=depth_coefficient,
                                 depth_divisor=8,
                                 min_depth=None)
    return BlockDecoder.decode(block_strings), global_params


def get_model_params(model_name, override_params):
    """Resolve block args and global params for ``model_name``.

    Args:
        model_name: 'efficientnet-bX' variant name.
        override_params: optional dict of GlobalParams fields to overwrite.

    Raises:
        NotImplementedError: for names outside the efficientnet family.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
    width, depth, _, dropout = efficientnet_params(model_name)
    blocks_args, global_params = efficientnet(width_coefficient=width,
                                              depth_coefficient=depth,
                                              dropout_rate=dropout)
    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


def round_filters(filters, global_params):
    """Scale ``filters`` by the width coefficient and round to the divisor.

    Returns the input unchanged when no width coefficient is set.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    scaled = filters * multiplier
    floor = global_params.min_depth or divisor
    rounded = max(floor, int(scaled + divisor / 2) // divisor * divisor)
    # Never round down by more than 10% of the scaled value.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)


def round_repeats(repeats, global_params):
    """Scale a stage's repeat count by the depth coefficient (ceiling).

    Returns the input unchanged when no depth coefficient is set.
    """
    coeff = global_params.depth_coefficient
    return repeats if not coeff else int(math.ceil(coeff * repeats))


class BlockDecoder(object):
    """
    Block Decoder, straight from the official TensorFlow repository.

    Translates between BlockArgs namedtuples and compact string notation
    such as 'r2_k3_s22_e6_i16_o24_se0.25'.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """ Gets a block through a string notation of arguments. """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split a token like 'k3' into key 'k' and value '3'.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: must be present and either a single digit or two
        # equal digits. Evaluate lazily so a missing 's' fails the assert
        # instead of raising KeyError first (the old code indexed
        # options['s'] unconditionally).
        assert 's' in options and (len(options['s']) == 1 or
                                   (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(kernel_size=int(options['k']),
                         num_repeat=int(options['r']),
                         input_filters=int(options['i']),
                         output_filters=int(options['o']),
                         expand_ratio=int(options['e']),
                         id_skip=('noskip' not in block_string),
                         se_ratio=float(options['se']) if 'se' in options else None,
                         stride=[int(options['s'][0])])

    @staticmethod
    def _encode_block_string(block):
        """Encodes a block to a string (inverse of ``_decode_block_string``)."""
        # BlockArgs stores 'stride' (often a one-element list), not 'strides';
        # the old code read a nonexistent 'strides' attribute and indexed [1].
        stride = block.stride
        if isinstance(stride, (list, tuple)):
            s0 = stride[0]
            s1 = stride[1] if len(stride) > 1 else stride[0]
        else:
            s0 = s1 = stride
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (s0, s1),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        # Guard against se_ratio=None (decode emits None when 'se' is absent).
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """
        Decode a list of string notations to specify blocks in the network.
        string_list: list of strings, each string is a notation of block
        return
            list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.
        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings


def initial_type(name, use_bias=False):
    """Build the (weight, bias) ParamAttr pair for a conv layer named *name*.

    Returns ``False`` as the bias attr when ``use_bias`` is off, which
    disables the bias parameter entirely.
    """
    weight = ParamAttr(name=name + "_weights")
    bias = ParamAttr(name=name + "_offset") if use_bias else False
    return weight, bias


def init_batch_norm_layer(name="batch_norm"):
    """Build the (scale, offset) ParamAttr pair for a BatchNorm named *name*."""
    return ParamAttr(name=name + "_scale"), ParamAttr(name=name + "_offset")


def init_fc_layer(name="fc"):
    """Build the (weights, offset) ParamAttr pair for a Linear layer named *name*."""
    return ParamAttr(name=name + "_weights"), ParamAttr(name=name + "_offset")


def cal_padding(img_size, stride, filter_size, dilation=1):
    """Compute (front, back) TF-style 'SAME' padding for one spatial dim.

    NOTE: ``dilation`` is accepted for interface symmetry but does not enter
    the computation — TODO confirm this is intended.
    """
    remainder = img_size % stride
    total = max(filter_size - (stride if remainder == 0 else remainder), 0)
    front = total // 2
    return front, total - front


# inp_shape = {
#     "b0_small": [224, 112, 112, 56, 28, 14, 14, 7],
#     "b0": [224, 112, 112, 56, 28, 14, 14, 7],
#     "b1": [240, 120, 120, 60, 30, 15, 15, 8],
#     "b2": [260, 130, 130, 65, 33, 17, 17, 9],
#     "b3": [300, 150, 150, 75, 38, 19, 19, 10],
#     "b4": [380, 190, 190, 95, 48, 24, 24, 12],
#     "b5": [456, 228, 228, 114, 57, 29, 29, 15],
#     "b6": [528, 264, 264, 132, 66, 33, 33, 17],
#     "b7": [600, 300, 300, 150, 75, 38, 38, 19]
# }
# Per-stage input spatial sizes (stages 0..7), keyed by variant name; used by
# Conv2ds to pre-compute TF-style "SAME" padding from a static input size.
inp_shape = {
    "b0_small": [224, 112, 112, 56, 28, 14, 14, 7],
    "b0": [224, 112, 112, 56, 28, 14, 14, 7],
    "b1": [240, 120, 120, 60, 30, 15, 15, 8],
    "b2": [260, 130, 130, 65, 33, 17, 17, 9],
    "b3": [300, 150, 150, 75, 38, 19, 19, 10],
    "b4": [380, 190, 190, 95, 48, 24, 24, 12],
    # NOTE(review): "b5" diverges from the 456-based reference sizes kept in the
    # commented block above — presumably retuned for a 256x256 input; confirm.
    "b5": [256, 128, 128, 64, 32, 16, 16, 8],
    "b6": [528, 264, 264, 132, 66, 33, 33, 17],
    "b7": [600, 300, 300, 150, 75, 38, 38, 19]
}


def _drop_connect(inputs, prob, is_test):
    """Stochastic depth: randomly zero whole samples with probability ``prob``.

    At test time the input passes through unchanged; at train time surviving
    samples are rescaled by 1/keep_prob so the expectation is preserved.
    """
    if is_test:
        return inputs
    keep_prob = 1.0 - prob
    batch = paddle.shape(inputs)[0]
    # Per-sample Bernoulli mask: floor(keep_prob + U[0,1)) is 1 w.p. keep_prob.
    mask = paddle.floor(keep_prob + paddle.rand(shape=[batch, 1, 1, 1]))
    return paddle.multiply(inputs, mask) / keep_prob


class Conv2ds(nn.Layer):
    """Conv2D wrapper adding TF-style padding modes and an optional activation.

    padding_type:
        "SAME"    - emulate TF SAME padding using the static per-stage input
                    size from ``inp_shape``; asymmetric cases over-pad and
                    crop one row/column in forward().
        "VALID"   - no padding.
        "DYNAMIC" - symmetric padding that keeps spatial size at stride 1.
        otherwise - passed straight through to Conv2D as the padding value.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride=1,
                 padding=0,
                 groups=None,
                 name="conv2d",
                 act=None,
                 use_bias=False,
                 padding_type=None,
                 model_name=None,
                 cur_stage=None):
        super(Conv2ds, self).__init__()
        assert act in [None, "swish", "sigmoid"]
        self.act = act

        param_attr, bias_attr = initial_type(name=name, use_bias=use_bias)

        self.padding_type = padding_type
        self.stride = stride
        self.filter_size = filter_size

        def get_padding(filter_size, stride=1, dilation=1):
            # Symmetric padding that preserves spatial size when stride == 1.
            padding = ((stride - 1) + dilation * (filter_size - 1)) // 2
            return padding

        # Static input size for this stage; 1 when no model/stage is given.
        inps = 1 if model_name is None and cur_stage is None else inp_shape[model_name][cur_stage]
        self.need_crop = False
        if padding_type == "SAME":
            top_padding, bottom_padding = cal_padding(inps, stride, filter_size)
            left_padding, right_padding = cal_padding(inps, stride, filter_size)
            height_padding = bottom_padding
            width_padding = right_padding
            if top_padding != bottom_padding or left_padding != right_padding:
                # Asymmetric SAME padding: over-pad, then crop the extra
                # top row and left column in forward().
                height_padding = top_padding + stride
                width_padding = left_padding + stride
                self.need_crop = True
            padding = [height_padding, width_padding]
        elif padding_type == "VALID":
            height_padding = 0
            width_padding = 0
            padding = [height_padding, width_padding]
        elif padding_type == "DYNAMIC":
            padding = get_padding(filter_size, stride)
        else:
            padding = padding_type

        groups = 1 if groups is None else groups
        # Leftover debug prints of the layer's channel config were removed here.
        self._conv = Conv2D(
            input_channels,
            output_channels,
            filter_size,
            groups=groups,
            stride=stride,
            padding=padding,
            weight_attr=param_attr,
            bias_attr=bias_attr)

    def forward(self, inputs):
        x = self._conv(inputs)
        if self.act == "swish":
            x = F.swish(x)
        elif self.act == "sigmoid":
            x = F.sigmoid(x)

        if self.need_crop:
            x = x[:, :, 1:, 1:]
        return x


class ConvBNLayer(nn.Layer):
    """Conv2ds followed by an optional BatchNorm (with its own activation).

    When ``use_bn`` is True the BN layer carries ``bn_act``; the conv itself
    applies ``conv_act`` (usually None here).
    """

    def __init__(self,
                 input_channels,
                 filter_size,
                 output_channels,
                 stride=1,
                 num_groups=1,
                 padding_type="SAME",
                 conv_act=None,
                 bn_act="swish",
                 use_bn=True,
                 use_bias=False,
                 name=None,
                 conv_name=None,
                 bn_name=None,
                 model_name=None,
                 cur_stage=None):
        super(ConvBNLayer, self).__init__()
        self._conv = Conv2ds(input_channels=input_channels,
                             output_channels=output_channels,
                             filter_size=filter_size,
                             stride=stride,
                             groups=num_groups,
                             act=conv_act,
                             padding_type=padding_type,
                             name=conv_name,
                             use_bias=use_bias,
                             model_name=model_name,
                             cur_stage=cur_stage)
        self.use_bn = use_bn
        if use_bn is True:
            bn_name = name + bn_name
            param_attr, bias_attr = init_batch_norm_layer(bn_name)
            self._bn = BatchNorm(num_channels=output_channels,
                                 act=bn_act,
                                 momentum=0.99,
                                 epsilon=0.001,
                                 moving_mean_name=bn_name + "_mean",
                                 moving_variance_name=bn_name + "_variance",
                                 param_attr=param_attr,
                                 bias_attr=bias_attr)

    def forward(self, inputs):
        # Leftover debug prints of intermediate shapes were removed here.
        x = self._conv(inputs)
        if self.use_bn:
            x = self._bn(x)
        return x


class ExpandConvNorm(nn.Layer):
    """1x1 expansion conv + BN (no activation); identity when expand_ratio == 1."""

    def __init__(self, input_channels, block_args, padding_type, name=None, model_name=None, cur_stage=None):
        super(ExpandConvNorm, self).__init__()
        self.oup = block_args.input_filters * block_args.expand_ratio
        self.expand_ratio = block_args.expand_ratio
        # No sublayer is created at ratio 1: the block is a pure pass-through.
        if self.expand_ratio != 1:
            self._conv = ConvBNLayer(input_channels,
                                     1,
                                     self.oup,
                                     bn_act=None,
                                     padding_type=padding_type,
                                     name=name,
                                     conv_name=name + "_expand_conv",
                                     bn_name="_bn0",
                                     model_name=model_name,
                                     cur_stage=cur_stage)

    def forward(self, inputs):
        if self.expand_ratio == 1:
            return inputs
        return self._conv(inputs)


class DepthwiseConvNorm(nn.Layer):
    """Depthwise conv + BN (no activation) for the middle of an MBConv block."""

    def __init__(self, input_channels, block_args, padding_type, name=None, model_name=None, cur_stage=None):
        super(DepthwiseConvNorm, self).__init__()
        self.k = block_args.kernel_size
        stride = block_args.stride
        if isinstance(stride, (list, tuple)):
            stride = stride[0]
        self.s = stride
        expanded_channels = block_args.input_filters * block_args.expand_ratio
        self._conv = ConvBNLayer(input_channels,
                                 self.k,
                                 expanded_channels,
                                 self.s,
                                 num_groups=input_channels,  # depthwise: one group per input channel
                                 bn_act=None,
                                 padding_type=padding_type,
                                 name=name,
                                 conv_name=name + "_depthwise_conv",
                                 bn_name="_bn1",
                                 model_name=model_name,
                                 cur_stage=cur_stage)

    def forward(self, inputs):
        return self._conv(inputs)


class ProjectConvNorm(nn.Layer):
    """1x1 projection conv + BN (no activation) back to the block's output width."""

    def __init__(self, input_channels, block_args, padding_type, name=None, model_name=None, cur_stage=None):
        super(ProjectConvNorm, self).__init__()
        self._conv = ConvBNLayer(input_channels,
                                 1,
                                 block_args.output_filters,
                                 bn_act=None,
                                 padding_type=padding_type,
                                 name=name,
                                 conv_name=name + "_project_conv",
                                 bn_name="_bn2",
                                 model_name=model_name,
                                 cur_stage=cur_stage)

    def forward(self, inputs):
        return self._conv(inputs)


class SEBlock(nn.Layer):
    """Squeeze-and-excitation: global pool -> reduce conv -> expand conv -> scale."""

    def __init__(self,
                 input_channels,
                 num_squeezed_channels,
                 oup,
                 padding_type,
                 name=None,
                 model_name=None,
                 cur_stage=None):
        super(SEBlock, self).__init__()
        self._pool = AdaptiveAvgPool2D(1)
        self._conv1 = Conv2ds(input_channels,
                              num_squeezed_channels,
                              1,
                              use_bias=True,
                              padding_type=padding_type,
                              act="swish",
                              name=name + "_se_reduce")
        # The sigmoid gate is applied in forward(), not inside this conv.
        self._conv2 = Conv2ds(num_squeezed_channels,
                              oup,
                              1,
                              act=None,
                              use_bias=True,
                              padding_type=padding_type,
                              name=name + "_se_expand")

    def forward(self, inputs):
        gate = self._pool(inputs)
        gate = self._conv1(gate)
        gate = self._conv2(gate)
        return paddle.multiply(inputs, F.sigmoid(gate))


class MbConvBlock(nn.Layer):
    """Mobile inverted bottleneck (MBConv): expand -> depthwise -> SE -> project,
    with drop-connect and a residual add when shape and stride allow it."""

    def __init__(self,
                 input_channels,
                 block_args,
                 padding_type,
                 use_se,
                 name=None,
                 drop_connect_rate=None,
                 model_name=None,
                 cur_stage=None):
        super(MbConvBlock, self).__init__()

        expanded_filters = block_args.input_filters * block_args.expand_ratio
        self.block_args = block_args
        se_ratio = block_args.se_ratio
        self.has_se = use_se and (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip
        self.expand_ratio = block_args.expand_ratio
        self.drop_connect_rate = drop_connect_rate

        if self.expand_ratio != 1:
            self._ecn = ExpandConvNorm(input_channels,
                                       block_args,
                                       padding_type=padding_type,
                                       name=name,
                                       model_name=model_name,
                                       cur_stage=cur_stage)

        self._dcn = DepthwiseConvNorm(input_channels * block_args.expand_ratio,
                                      block_args,
                                      padding_type=padding_type,
                                      name=name,
                                      model_name=model_name,
                                      cur_stage=cur_stage)

        if self.has_se:
            squeezed = max(1, int(block_args.input_filters * block_args.se_ratio))
            self._se = SEBlock(input_channels * block_args.expand_ratio,
                               squeezed,
                               expanded_filters,
                               padding_type=padding_type,
                               name=name,
                               model_name=model_name,
                               cur_stage=cur_stage)

        self._pcn = ProjectConvNorm(input_channels * block_args.expand_ratio,
                                    block_args,
                                    padding_type=padding_type,
                                    name=name,
                                    model_name=model_name,
                                    cur_stage=cur_stage)

    def forward(self, inputs):
        y = inputs
        if self.expand_ratio != 1:
            y = F.swish(self._ecn(y))
        y = F.swish(self._dcn(y))
        if self.has_se:
            y = self._se(y)
        y = self._pcn(y)
        # Residual only for stride-1 blocks that keep the channel count.
        # (stride may be a one-element list here, in which case `== 1` is
        # False and no skip is taken — behavior kept exactly as before.)
        can_skip = (self.id_skip
                    and self.block_args.stride == 1
                    and self.block_args.input_filters == self.block_args.output_filters)
        if can_skip:
            if self.drop_connect_rate:
                y = _drop_connect(y, self.drop_connect_rate, not self.training)
            y = paddle.add(y, inputs)
        return y


class ConvStemNorm(nn.Layer):
    """Stem: 3x3 stride-2 conv + BN (no activation) applied to the raw input."""

    def __init__(self, input_channels, padding_type, _global_params, name=None, model_name=None, cur_stage=None):
        super(ConvStemNorm, self).__init__()
        stem_channels = round_filters(32, _global_params)
        self._conv = ConvBNLayer(input_channels,
                                 filter_size=3,
                                 output_channels=stem_channels,
                                 stride=2,
                                 bn_act=None,
                                 padding_type=padding_type,
                                 name="",
                                 conv_name="_conv_stem",
                                 bn_name="_bn0",
                                 model_name=model_name,
                                 cur_stage=cur_stage)

    def forward(self, inputs):
        return self._conv(inputs)


class ExtractFeatures(nn.Layer):
    """Stem conv plus the full stack of MBConv blocks (the feature extractor)."""

    def __init__(self, input_channels, _block_args, _global_params, padding_type, use_se, model_name=None):
        super(ExtractFeatures, self).__init__()

        self._global_params = _global_params

        self._conv_stem = ConvStemNorm(input_channels,
                                       padding_type=padding_type,
                                       _global_params=_global_params,
                                       model_name=model_name,
                                       cur_stage=0)

        # Total number of blocks after depth scaling; used below to ramp the
        # per-block drop-connect rate linearly from 0 up to the global rate.
        self.block_args_copy = copy.deepcopy(_block_args)
        block_size = 0
        for block_arg in self.block_args_copy:
            block_size += round_repeats(block_arg.num_repeat, _global_params)

        idx = 0
        self.conv_seq = []
        cur_stage = 1
        for block_args in _block_args:
            block_args = block_args._replace(input_filters=round_filters(block_args.input_filters, _global_params),
                                             output_filters=round_filters(block_args.output_filters, _global_params),
                                             num_repeat=round_repeats(block_args.num_repeat, _global_params))
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / block_size
            _mc_block = self.add_sublayer(
                "_blocks." + str(idx) + ".",
                MbConvBlock(block_args.input_filters,
                            block_args=block_args,
                            padding_type=padding_type,
                            use_se=use_se,
                            name="_blocks." + str(idx) + ".",
                            drop_connect_rate=drop_connect_rate,
                            model_name=model_name,
                            cur_stage=cur_stage))
            self.conv_seq.append(_mc_block)
            idx += 1
            if block_args.num_repeat > 1:
                # Repeated blocks within a stage keep the stage's output width
                # and run at stride 1.
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                drop_connect_rate = self._global_params.drop_connect_rate
                if drop_connect_rate:
                    drop_connect_rate *= float(idx) / block_size
                # NOTE(review): sublayer key "block." differs from the
                # "_blocks." used above; kept as-is so checkpoint structured
                # names remain compatible — confirm before unifying.
                _mc_block = self.add_sublayer(
                    "block." + str(idx) + ".",
                    MbConvBlock(block_args.input_filters,
                                block_args,
                                padding_type=padding_type,
                                use_se=use_se,
                                name="_blocks." + str(idx) + ".",
                                drop_connect_rate=drop_connect_rate,
                                model_name=model_name,
                                cur_stage=cur_stage))
                self.conv_seq.append(_mc_block)
                idx += 1
            cur_stage += 1

    def forward(self, inputs):
        # Leftover debug prints of input/stem shapes were removed here.
        x = self._conv_stem(inputs)
        x = F.swish(x)
        for _mc_block in self.conv_seq:
            x = _mc_block(x)
        return x


class EfficientNet(nn.Layer):
    """EfficientNet backbone producing a ``class_num``-dim tanh feature vector.

    Args:
        name: variant suffix, e.g. "b0".."b7". (NOTE(review): "b0_small" maps
            to 'efficientnet-b0_small', which ``efficientnet_params`` does not
            know — callers such as EfficientNetB0_small pass name='b0'.)
        padding_type: "SAME", "VALID", "DYNAMIC" or an explicit padding value.
        override_params: optional dict merged into the GlobalParams.
        use_se: whether MBConv blocks use squeeze-and-excitation.
        class_num: output dimension of the final Linear layer.
    """

    # Input channel count of the 1x1 head conv for each supported variant.
    _HEAD_IN_CHANNELS = {
        "b0_small": 320, "b0": 320, "b1": 320, "b2": 352, "b3": 384,
        "b4": 448, "b5": 512, "b6": 576, "b7": 640,
    }

    def __init__(self, name="b0", padding_type="SAME", override_params=None, use_se=True, class_num=768):
        super(EfficientNet, self).__init__()

        model_name = 'efficientnet-' + name
        self.name = name
        self._block_args, self._global_params = get_model_params(model_name, override_params)
        self.padding_type = padding_type
        self.use_se = use_se

        self._ef = ExtractFeatures(3,
                                   self._block_args,
                                   self._global_params,
                                   self.padding_type,
                                   self.use_se,
                                   model_name=self.name)

        output_channels = round_filters(1280, self._global_params)

        try:
            oup = self._HEAD_IN_CHANNELS[name]
        except KeyError:
            # Previously an unknown name fell through the if/elif chain and
            # crashed later with NameError on `oup`.
            raise ValueError("unsupported EfficientNet variant: %s" % name)
        self._conv = ConvBNLayer(oup,
                                 1,
                                 output_channels,
                                 bn_act="swish",
                                 padding_type=self.padding_type,
                                 name="",
                                 conv_name="_conv_head",
                                 bn_name="_bn1",
                                 model_name=self.name,
                                 cur_stage=7)
        self._pool = AdaptiveAvgPool2D(1)

        if self._global_params.dropout_rate:
            self._drop = Dropout(p=self._global_params.dropout_rate, mode="upscale_in_train")

        self._fc = Linear(output_channels,
                          class_num,
                          weight_attr=paddle.ParamAttr(name='image_trans_w'),
                          bias_attr=paddle.ParamAttr(name='image_trans_b'))

    def forward(self, inputs):
        # Leftover debug prints of intermediate shapes were removed here.
        x = self._ef(inputs)
        x = self._conv(x)
        x = self._pool(x)
        if self._global_params.dropout_rate:
            x = self._drop(x)
        x = paddle.squeeze(x, axis=[2, 3])
        x = self._fc(x)
        x = F.tanh(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Load weights into ``model`` when ``pretrained`` is a path string.

    ``model_url`` and ``use_ssld`` are accepted for interface symmetry with
    the builder functions but are not used here.
    """
    if pretrained is False:
        return
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
        return
    raise RuntimeError("pretrained type is not available. Please use `string` type.")


def EfficientNetB0_small(padding_type='DYNAMIC',
                         override_params=None,
                         use_se=False,
                         pretrained=False,
                         use_ssld=False,
                         **kwargs):
    """Build EfficientNet-B0 (small: no SE, DYNAMIC padding); optionally load weights."""
    net = EfficientNet(name='b0', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB0_small"])
    return net


def EfficientNetB0(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B0; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b0', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB0"])
    return net


def EfficientNetB1(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B1; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b1', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB1"])
    return net


def EfficientNetB2(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B2; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b2', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB2"])
    return net


def EfficientNetB3(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B3; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b3', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB3"])
    return net


def EfficientNetB4(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B4; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b4', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB4"])
    return net


def EfficientNetB5(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B5; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b5', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB5"])
    return net


def EfficientNetB6(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B6; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b6', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB6"])
    return net


def EfficientNetB7(padding_type='SAME', override_params=None, use_se=True, pretrained=False, use_ssld=False, **kwargs):
    """Build EfficientNet-B7; optionally load pretrained weights from ``pretrained``."""
    net = EfficientNet(name='b7', padding_type=padding_type, override_params=override_params, use_se=use_se, **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["EfficientNetB7"])
    return net
